from torch.utils.data import Dataset, DataLoader
from sft_train import seed_everything
import argparse
import os
import json
import torch
from tqdm import tqdm
from datetime import datetime
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
from transformers.trainer_pt_utils import LabelSmoother
import wandb
from inference import inference, GSM8kTestDataset
from pebble import ProcessPool
from concurrent.futures import TimeoutError as FutureTimeoutError
from multiprocessing import Process, Manager
from grader import math_equal_process
from parser_utils import extract_answer
import random
import numpy as np
from eval import eval_file

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

class GSM8kChatDataset(Dataset):
    """Evaluation dataset that yields chat-formatted prompts for instruct models.

    Each item is the question wrapped in the tokenizer's chat template (a fixed
    system instruction followed by the user turn) with the generation prompt
    appended. The result is returned as an untokenized string.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        messages = [
            {"role": "system", "content": "Please solve the following problem step by step."},
            {"role": "user", "content": self.data[idx]["question"]},
        ]
        return self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

class GSM8kInstructDataset(Dataset):
    """Supervised fine-tuning dataset for chat/instruct models.

    The question is rendered with the tokenizer's chat template (system
    instruction, user turn and generation prompt). The reference answer is
    tokenized separately and appended, followed by one ``pad_token_id`` that
    marks the end of the answer. Prompt positions get ``-100`` in ``labels``,
    so only the answer tokens and the trailing marker contribute to the loss.

    ``max_length`` and ``device`` are stored as attributes but are not used
    when building items.
    """

    def __init__(self, data, tokenizer, max_length=1024, device='cuda:0'):
        self.tokenizer = tokenizer
        self.data = data
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        question, answer = record["question"], record["answer"]

        prompt_text = self.tokenizer.apply_chat_template(
            [
                {"role": "system", "content": "Please solve the following problem step by step."},
                {"role": "user", "content": question},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        prompt = self.tokenizer(prompt_text, add_special_tokens=False)
        reply = self.tokenizer(answer, add_special_tokens=False)

        prompt_ids = prompt["input_ids"]
        reply_ids = reply["input_ids"]
        # The pad token is appended as the end-of-answer marker.
        end_marker = self.tokenizer.pad_token_id

        input_ids = prompt_ids + reply_ids + [end_marker]
        labels = [-100] * len(prompt_ids) + reply_ids + [end_marker]
        attention_mask = prompt["attention_mask"] + reply["attention_mask"] + [1]

        return {
            "input_ids": input_ids,
            "labels": labels,
            'start_idx': len(prompt_ids),
            'seq_len': len(input_ids),
            'attention_mask': attention_mask,
        }

class GSM8kBaseDataset(Dataset):
    """Supervised fine-tuning dataset for base (non-chat) models.

    Each example is encoded as the plain text ``f"{question} {answer}"``
    (tokenized with ``truncation=True``) followed by ``eos_token_id``.
    ``labels`` start as a copy of ``input_ids``. The first ``start_idx + 1``
    positions are then set to ``IGNORE_TOKEN_ID`` so they do not contribute
    to the loss.

    ``max_length`` and ``device`` are stored for parity with
    ``GSM8kInstructDataset`` but are not used when encoding.
    """

    def __init__(self, data, tokenizer, max_length=1024, device='cuda:0'):
        self.tokenizer = tokenizer
        self.data = data
        # Not passed to the tokenizer; truncation falls back to the tokenizer's
        # own model_max_length.
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        question = example["question"]
        answer = example["answer"]

        # Concatenate question and answer into a single plain-text sequence.
        full_text = f"{question} {answer}"
        encoding = self.tokenizer(full_text, truncation=True)
        eos_id = self.tokenizer.eos_token_id
        input_ids = encoding["input_ids"] + [eos_id]
        labels = encoding["input_ids"] + [eos_id]
        attention_mask = encoding['attention_mask'] + [1]

        # The original code describes start_idx as the "index of '?'" ending
        # the question (presumably a token index -- TODO confirm against the
        # data-prep script). Everything up to and including it is masked.
        start_idx = example['start_idx']
        # Clamp to [0, len(labels)]. A replacement longer than the slice would
        # grow `labels` past `input_ids` (e.g. after truncation), and a
        # negative count would delete label entries instead of masking them.
        n_masked = max(0, min(start_idx + 1, len(labels)))
        labels[:n_masked] = [IGNORE_TOKEN_ID] * n_masked

        seq_len = len(input_ids)
        return {
            "input_ids": input_ids,
            "labels": labels,
            'start_idx': start_idx,
            'seq_len': seq_len,
            'attention_mask': attention_mask,
        }

class GSM8kBaseChatDataset(Dataset):
    """Evaluation dataset for base models: each item is the raw question text.

    No template is applied. The tokenizer is kept only so the constructor
    matches ``GSM8kChatDataset``.
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        return record["question"]

def collate_fn(batch, tokenizer):
    """Right-pad a list of tokenized training examples into batch tensors.

    Each example must provide ``input_ids``, ``labels`` and ``attention_mask``
    lists. Padding values:

    * ``input_ids`` use ``tokenizer.pad_token_id``;
    * ``labels`` use ``IGNORE_TOKEN_ID``, so padding is excluded from the loss;
    * ``attention_mask`` uses 0.

    Returns a dict of ``torch.long`` tensors of shape (batch, longest_seq).
    """
    ids_seqs = [torch.tensor(example["input_ids"], dtype=torch.long) for example in batch]
    label_seqs = [torch.tensor(example["labels"], dtype=torch.long) for example in batch]
    mask_seqs = [torch.tensor(example["attention_mask"], dtype=torch.long) for example in batch]

    pad = torch.nn.utils.rnn.pad_sequence
    padded = {
        "input_ids": pad(ids_seqs, batch_first=True, padding_value=tokenizer.pad_token_id),
        "labels": pad(label_seqs, batch_first=True, padding_value=IGNORE_TOKEN_ID),
        "attention_mask": pad(mask_seqs, batch_first=True, padding_value=0),
    }
    return {
        "input_ids": padded["input_ids"],
        "attention_mask": padded["attention_mask"],
        "labels": padded["labels"],
    }
    
def collate_fn_eval(batch, tokenizer):
    """Tokenize a list of prompt strings into one padded PyTorch batch.

    Inputs are truncated to at most 1024 tokens. The padding side is whatever
    the tokenizer was configured with.
    """
    encode_options = dict(
        padding=True,
        truncation=True,
        max_length=1024,  # limit input length
        return_tensors="pt",
    )
    return tokenizer(batch, **encode_options)
def warmup_lr_scheduler(step, warmup_steps=None):
    """Linear-warmup LR multiplier, usable as ``LambdaLR``'s ``lr_lambda``.

    Returns ``step / warmup_steps`` while ``step < warmup_steps`` and ``1.0``
    afterwards, i.e. a constant learning rate once warmup ends.

    Args:
        step: Current scheduler step (``LambdaLR`` passes it positionally).
        warmup_steps: Number of warmup steps. If omitted, the module-level
            ``warmup_steps`` global is used, which was the previous behaviour,
            so ``LambdaLR(opt, lr_lambda=warmup_lr_scheduler)`` keeps working.
            Prefer ``functools.partial(warmup_lr_scheduler, warmup_steps=N)``.

    Raises:
        NameError: if ``warmup_steps`` is neither passed nor defined globally.
    """
    if warmup_steps is None:
        try:
            warmup_steps = globals()["warmup_steps"]
        except KeyError:
            raise NameError(
                "warmup_steps was not passed and no module-level 'warmup_steps' is defined"
            ) from None
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))  # linear warmup
    return 1.0  # warmup finished: keep the base learning rate

def eval_model(model, base_model_name, data, device, is_instruct, batch_size=16, path=''):
    """Greedy-decode every question in ``data``, write the generations, and score them.

    Args:
        model: Causal LM exposing ``generate`` (plain HF or PEFT-wrapped).
        base_model_name: Name/path used to load a left-padding tokenizer.
        data: List of dicts, each with at least a ``"question"`` key. The
            dicts are not modified; each output line is a copy of the record
            with an added ``"generated"`` field.
        device: Device the input tensors are moved to.
        is_instruct: If True, prompts use the chat template
            (``GSM8kChatDataset``); otherwise the raw question is used
            (``GSM8kBaseChatDataset``).
        batch_size: Generation batch size.
        path: JSONL output file that is then passed to ``eval_file``. Missing
            parent directories are created.

    Returns:
        Whatever ``eval_file(path)`` returns (the training loop averages it as
        a list of per-example scores).

    Raises:
        ValueError: if ``path`` is empty. This is checked before the expensive
            generation pass rather than failing on ``open`` afterwards.

    Note:
        The model is left in eval mode; callers must switch back to train mode.
    """
    if not path:
        raise ValueError("eval_model requires a non-empty output `path`")

    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token, and batched padding in
        # collate_fn_eval would fail. This mirrors the training-side fallback.
        tokenizer.pad_token = tokenizer.eos_token

    dataset_cls = GSM8kChatDataset if is_instruct else GSM8kBaseChatDataset
    dataset = dataset_cls(data, tokenizer)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=lambda x: collate_fn_eval(x, tokenizer))

    generation_config = {
        "max_new_tokens": 2048,
        "do_sample": False,  # greedy decoding
        "pad_token_id": tokenizer.eos_token_id,
        "use_cache": True,
    }

    inference_results = []
    with torch.no_grad():
        for batch in tqdm(dataloader):
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model.generate(**inputs, **generation_config)
            # NOTE(review): the full output sequence (prompt + continuation) is
            # decoded, so "generated" includes the prompt text. This presumably
            # relies on eval_file extracting the final answer -- confirm before
            # changing it.
            inference_results.extend(tokenizer.batch_decode(outputs, skip_special_tokens=True))

    # The target directory (e.g. a checkpoint folder) may not exist yet when
    # checkpoints are not being saved.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, 'w') as f:
        for text, item in zip(inference_results, data):
            f.write(json.dumps({**item, "generated": text}) + '\n')

    return eval_file(path)



if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_model', type=str, default='Qwen/Qwen2.5-1.5B-Instruct')
    parser.add_argument('--data_path', type=str,
                        help='JSONL file under ./math-shepherd; required when --mix is 0.')
    parser.add_argument('--num_epochs', type=int, default=1)
    parser.add_argument('--seed', type=int, default=2025)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--use_scheduler', action='store_true')
    # Evaluate every N optimizer steps. This flag used to be parsed but ignored
    # in favour of a hard-coded 100; the default now matches that value.
    parser.add_argument('--evaluation_interval', type=int, default=100)
    # Stop after N optimizer steps (<= 0 disables the cap). This replaces the
    # previously hard-coded limit of 200.
    parser.add_argument('--max_steps', type=int, default=200)

    parser.add_argument('--r', type=int, default=16)
    parser.add_argument('--lora_alpha', type=int, default=64)
    parser.add_argument('--lora_dropout', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-6)

    parser.add_argument('--target_modules', type=str, nargs='+', default=['q_proj', 'k_proj', 'v_proj', 'o_proj'])
    parser.add_argument('--modules_to_save', type=str, nargs='+', default=['wte', 'lm_head'])
    parser.add_argument('--store_model', action='store_true')
    parser.add_argument('--full_param', action='store_true')
    # Fraction of the expert data to use. The complementary fraction of the
    # imperfect data is added. 0 means "train on --data_path only".
    parser.add_argument('--mix', type=float, default=0)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)
    args = parser.parse_args()
    if args.mix == 0 and not args.data_path:
        parser.error('--data_path is required when --mix is 0')
    if args.evaluation_interval <= 0:
        parser.error('--evaluation_interval must be a positive integer')
    if args.gradient_accumulation_steps <= 0:
        parser.error('--gradient_accumulation_steps must be a positive integer')
    print(args)

    # -------------------------------------------------------
    # 0. Seeding: --seed was previously only logged, never applied. Seed
    #    before LoRA init and DataLoader shuffling consume random state.
    # -------------------------------------------------------
    seed = args.seed
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # -------------------------------------------------------
    # 1. Specify base model and dataset path
    # -------------------------------------------------------
    model_name = args.base_model
    mix = args.mix
    is_instruct = 'Instruct' in model_name
    print('Is Instruct:', is_instruct)
    if mix == 0:
        data_path = os.path.join('./math-shepherd', args.data_path)
        with open(data_path, 'r') as f:
            raw_data = [json.loads(line) for line in f]
    else:
        expert_data_path = './math-shepherd/expert_data_50000.jsonl'
        imperfect_data_path = './math-shepherd/imperfect_data_50000.jsonl'
        with open(expert_data_path, 'r') as f:
            expert_data = [json.loads(line) for line in f]
        with open(imperfect_data_path, 'r') as f:
            imperfect_data = [json.loads(line) for line in f]
        expert_data_num = int(len(expert_data) * mix)
        imperfect_data_num = int(len(imperfect_data) * (1 - mix))
        raw_data = expert_data[:expert_data_num] + imperfect_data[:imperfect_data_num]

    # -------------------------------------------------------
    # 2. Setup tokenizer (with trust_remote_code if needed)
    # -------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # For some models like Llama/Mistral

    # -------------------------------------------------------
    # 3. Load base model, then wrap with LoRA (unless --full_param)
    # -------------------------------------------------------
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        device_map="auto",  # optional, uses GPU if available
    )
    r = args.r
    lora_alpha = args.lora_alpha
    lora_dropout = args.lora_dropout
    target_modules = args.target_modules
    modules_to_save = args.modules_to_save

    lora_config = LoraConfig(
        r=r,                          # rank
        lora_alpha=lora_alpha,        # alpha scaling
        lora_dropout=lora_dropout,    # dropout
        bias="none",
        target_modules=target_modules,
        task_type=TaskType.CAUSAL_LM,
        modules_to_save=modules_to_save,
    )

    if args.full_param:
        model = base_model
    else:
        model = get_peft_model(base_model, lora_config)
        model.print_trainable_parameters()  # Show how many params will be trained

    # -------------------------------------------------------
    # 4. Prepare dataset/dataloader
    # -------------------------------------------------------
    dataset_cls = GSM8kInstructDataset if is_instruct else GSM8kBaseDataset
    dataset = dataset_cls(raw_data, tokenizer)
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    accumulation_steps = args.gradient_accumulation_steps
    max_steps = args.max_steps
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda x: collate_fn(x, tokenizer),
    )

    # Only the GSM8K test split is evaluated. The train split used to be
    # loaded here too but was never used.
    test_data_path = './gsm8k/test.jsonl'
    with open(test_data_path, 'r') as f:
        total_test_data = [json.loads(line) for line in f]

    # -------------------------------------------------------
    # 5. Optimizer + scheduler
    # -------------------------------------------------------
    lr = args.lr
    # Frozen (LoRA base) parameters never receive gradients, so optimizing only
    # the trainable ones gives identical updates.
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)
    num_training_steps = num_epochs * len(train_loader) // accumulation_steps
    if max_steps > 0:
        # Training stops at max_steps, so size the warmup/cosine horizon to it.
        # Otherwise warmup (3% of the full horizon) may never finish.
        num_training_steps = min(num_training_steps, max_steps)
    warmup_steps = int(num_training_steps * 0.03)
    use_scheduler = args.use_scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
    )

    current_time = datetime.now().strftime("%m%d-%H%M%S")
    modules = '-'.join(m.replace('_proj', '') for m in target_modules)
    trainable = '-'.join(modules_to_save) if modules_to_save is not None else ''

    if args.data_path:
        data_type = args.data_path.split('/')[-1].split('.')[0].replace('_data_sparse', '')
    else:
        # --data_path is optional when mixing; previously this crashed on None.
        data_type = f"mix{args.mix}"

    if args.full_param:
        if args.mix == 0:
            run_name = f"{model_name}-seed{seed}-Full-SFT-lr{lr}-{data_type}-{current_time}"
        else:
            run_name = f"{model_name}-seed{seed}-Full-SFT-lr{lr}-mix{args.mix}-{current_time}"
    else:
        if args.mix == 0:
            run_name = f"{model_name}-template-LoRA-seed{seed}-SFT-b{batch_size}-acc{accumulation_steps}-r{r}-alp{lora_alpha}-lr{lr}-{modules}-{trainable}-{data_type}-{current_time}"
        else:
            run_name = f"{model_name}-template-LoRA-seed{seed}-SFT-b{batch_size}-acc{accumulation_steps}-r{r}-alp{lora_alpha}-lr{lr}-mix{args.mix}-{modules}-{trainable}-{current_time}"

    # NOTE: 'Your wandb entity' is a placeholder; set your W&B team/user.
    if args.mix == 0:
        wandb.init(entity='Your wandb entity', project=f"MathRL-{data_type}", name=run_name, reinit=True)
    else:
        wandb.init(entity='Your wandb entity', project=f"MathRL-mix{args.mix}", name=run_name, reinit=True)

    wandb.config.update({
        'data_type': data_type,
        "learning_rate": lr,
        "batch_size": batch_size,
        "lora_dropout": lora_dropout,
        "lora_alpha": lora_alpha,
        "r": r,
        "target_modules": target_modules,
        "seed": seed
    })

    # -------------------------------------------------------
    # 6. Training loop
    # -------------------------------------------------------
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()

    evaluation_interval = args.evaluation_interval
    global_step = 0   # micro-batches processed so far
    actual_step = 0   # optimizer steps taken (global_step // accumulation_steps)
    best_accuracy = 0

    if args.store_model:
        os.makedirs(f'./models/{run_name}', exist_ok=True)

    # Hard-coded step-0 accuracies, logged instead of evaluating the untrained
    # model (presumably measured offline for the default base models -- confirm
    # before changing --base_model).
    if is_instruct:
        eval_accuracy = test_accuracy = 0.695982
    else:
        eval_accuracy = test_accuracy = 0.677786

    wandb.log({
        "eval_accuracy": eval_accuracy,
        "test_accuracy": test_accuracy
    })

    optimizer.zero_grad()
    stop_training = False

    for epoch in range(num_epochs):
        for step, batch in enumerate(tqdm(train_loader)):
            # Right after every `evaluation_interval`-th optimizer step (and at
            # step 0, which re-logs the baseline numbers): checkpoint + evaluate.
            at_step_boundary = global_step % accumulation_steps == 0
            if at_step_boundary and actual_step % evaluation_interval == 0:
                if actual_step != 0:
                    ckpt_dir = f"./models/{run_name}/ckpt-{actual_step}"
                    if args.store_model:
                        model.save_pretrained(ckpt_dir, safe_serialization=False)
                        tokenizer.save_pretrained(ckpt_dir)
                    # Generations are written here, so the directory must exist
                    # even when --store_model is off.
                    os.makedirs(ckpt_dir, exist_ok=True)
                    path = f'{ckpt_dir}/inference_results-greedy.jsonl'
                    scores = eval_model(model, base_model_name=model_name, data=total_test_data,
                                        device=device, path=path, is_instruct=is_instruct)
                    # Only the test split is evaluated, so both metrics share it.
                    eval_accuracy = test_accuracy = sum(scores) / len(scores)

                wandb.log({
                    "eval_accuracy": eval_accuracy,
                    "test_accuracy": test_accuracy
                })
                print(f"Eval Accuracy at step {actual_step} is {eval_accuracy}")
                print(f"Test Accuracy at step {actual_step} is {test_accuracy}")
                if test_accuracy > best_accuracy:
                    best_accuracy = test_accuracy
                    print('get best accuracy!')
                    if args.store_model:
                        model.save_pretrained(f"./models/{run_name}/best")
                        tokenizer.save_pretrained(f"./models/{run_name}/best")
                model.train()

            # Step cap, checked after the evaluation for this step. Training
            # past it was never evaluated or saved, so it is skipped.
            if max_steps > 0 and actual_step >= max_steps:
                stop_training = True
                break

            global_step += 1
            actual_step = global_step // accumulation_steps
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device)
            )
            loss = outputs.loss / accumulation_steps
            loss.backward()
            batch_loss = loss.item() * accumulation_steps

            # global_step has already been incremented, so this fires after
            # exactly `accumulation_steps` micro-batches. The old
            # `(global_step + 1)` check made the first update use only one.
            if global_step % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                if use_scheduler:
                    scheduler.step()
                current_lr = scheduler.get_last_lr()[0] if use_scheduler else lr
                wandb.log({
                    "epoch": epoch,
                    "global_step": global_step,
                    "step": actual_step,
                    "loss": batch_loss,
                    "learning_rate": current_lr
                })

            if global_step % 10 == 0:
                print(f"Epoch {epoch} | Step {step} | Global Step {global_step} | Actual Step {actual_step} | Loss {batch_loss}")

        # The step cap must also end the epoch loop. The old inner-only break
        # ran an extra step at the start of every remaining epoch.
        if stop_training:
            break

    wandb.finish()
